#########################################################################################################.
###                                                                                                   ###
###     Rage Against the Machine? Generative AI Exposure, Subjective Risk, and Policy Preferences     ###
###     Journal of European Public Policy                                                             ###
###                                                                                                   ###
###     Haslberger, Gingrich & Bhatia                                                                 ###
###                                                                                                   ###
###     Main Analysis                                                                                 ###
###                                                                                                   ###
#########################################################################################################.

# Setup ----

# Start from an empty workspace so objects from an earlier session cannot leak
# into the results.
rm(list = ls())

setwd("your_working_directory") #set working directory

# Load every package used in this script via pacman::p_load().
# The plotting package is called ggplot2; "ggplot" is not the name of the
# package that provides ggplot(), so it is listed under its real name here.
pacman::p_load(knitr, tidyverse, haven, validate, expss,
               dataMaid, sjlabelled, cobalt, stable, crosstable, gmodels, irr, readxl, ggplot2,
               ggeffects, grid, gridtext, gridExtra, corpus, lubridate, reshape2, tm, 
               quanteda, rsparse, stringi, readtext, SnowballC, 
               wordcloud, RColorBrewer, tidytext, widyr, irlba, 
               randomcoloR, stm, mediocrethemes, ggpubr,quanteda.textplots,quanteda.textstats,
               sentimentr, magrittr, lexicon, kableExtra, ggrepel, keyATM, 
               vtable, labelled, text2vec, flextable, AER, topicmodels, ggtext, stargazer) 

# Project helpers. Names used below that are not defined in this file (e.g.
# extract_model_info(), extract_marginal_effects(), make_title_grob(),
# onecolour, dodge_width) are presumably defined here -- confirm.
source("00_helperfunctions.r")

# All figures are saved to output/; create the folder if it does not exist yet.
dir.create("output", showWarnings = FALSE)

# Survey data and open-text data used throughout the script.
data <- read_rds("data.rds")
textdata <- read_rds("textdata.rds")

# Main Analysis ----

## Figure 2 ----

# Weighted regressions of each outcome on `treatment`, one model per outcome.
# The order of this list fixes the row order of the results table built below.
model_list <- list(
  t1.1 = lm(replace_skill2 ~ treatment, data = data, weights = weight),
  t1.2 = lm(replace_ai2 ~ treatment, data = data, weights = weight),
  t1.3 = lm(ai_surveillance2 ~ treatment, data = data, weights = weight),
  t1.4 = lm(societal_benefit2 ~ treatment, data = data, weights = weight),
  t1.5 = lm(personal_benefit2 ~ treatment, data = data, weights = weight),
  t1.6 = lm(gov_limit2 ~ treatment, data = data, weights = weight),
  t1.7 = lm(ai_threat2 ~ treatment, data = data, weights = weight)
)

# Extract information from all models: one row per model. Columns 1-5 of the
# helper's output are used below as the point estimate and the bounds of the
# 95% and 90% intervals.
results_data <- t(sapply(model_list, extract_model_info))

# Axis labels, in the same order as model_list.
fig2_terms <- c("Lose job: skilled human", "Lose job: technology", "AI will increase\n workplace surveillance",
                "AI has more benefits\n than drawbacks", "People like me are likely\n to benefit from AI", 
                "Government should\n limit the use of AI", "AI poses a serious\n threat to humanity")

results <- data.frame(
  Model = rownames(results_data),
  Coefficient = results_data[, 1],
  Lower95 = results_data[, 2],
  Upper95 = results_data[, 3],
  Lower90 = results_data[, 4],
  Upper90 = results_data[, 5],
  # The level order sets the order on the y axis (the first level is drawn at
  # the bottom). It is not a plain reversal: labels 4 and 5 keep their order.
  Term = factor(fig2_terms, levels = fig2_terms[c(7, 6, 4, 5, 3, 2, 1)])
)

# Create the coefficient plot: point estimate, thin 95% range, thick 90% range.
# No aesthetic is mapped to colour or fill, so no manual colour/fill scale is
# needed; the zero line is a single constant line rather than one per row.
p1 <- ggplot(results, aes(x = Coefficient, y = Term)) +
  geom_vline(xintercept = 0, linetype = "solid", color = "black") +
  geom_point(position = position_dodge(width = dodge_width), size = 4, color = onecolour) +
  geom_linerange(aes(xmin = Lower95, xmax = Upper95),
                 position = position_dodge(width = dodge_width), linewidth = 1, color = onecolour) +
  geom_linerange(aes(xmin = Lower90, xmax = Upper90),
                 position = position_dodge(width = dodge_width), linewidth = 2, color = onecolour) +
  theme_minimal() +
  theme(
    axis.title.y = element_blank(),
    title = element_text(size = 25),
    axis.title = element_text(size = 16),
    axis.text.y = element_text(size = 16),
    axis.text.x = element_text(size = 16),
    # "none" is the documented way to suppress the legend.
    legend.position = "none"
  ) +
  labs(x = "Difference in share `likely/agree`", title = "")

ggsave("output/figure_2.jpg",
       p1,
       width = 300, height = 150, units = "mm")


## Figure 3 ---- 

# Weighted regressions of the three government-responsibility outcomes on
# `treatment`. The list order fixes the row order of the results table below.
model_list <- list(
  t1.1 = lm(resp_training2 ~ treatment, data = data, weights = weight),
  t1.2 = lm(resp_job2 ~ treatment, data = data, weights = weight),
  t1.3 = lm(resp_benefits2 ~ treatment, data = data, weights = weight)
)

# Extract information from all models: one row per model. Columns 1-5 of the
# helper's output are used below as the point estimate and the bounds of the
# 95% and 90% intervals.
results_data <- t(sapply(model_list, extract_model_info))

# Axis labels, in the same order as model_list; the factor keeps this order.
fig3_terms <- c("Training opportunities for everyone", "A job for everyone who wants one", "Decent UE benefits")

results <- data.frame(
  Model = rownames(results_data),
  Coefficient = results_data[, 1],
  Lower95 = results_data[, 2],
  Upper95 = results_data[, 3],
  Lower90 = results_data[, 4],
  Upper90 = results_data[, 5],
  Term = factor(fig3_terms, levels = fig3_terms)
)

# Coefficient plot: point estimate, thin 95% range, thick 90% range.
# The zero line is drawn once as a constant instead of once per data row.
p1 <- ggplot(results, aes(x = Coefficient, y = Term)) +
  geom_vline(xintercept = 0, linetype = "solid", color = "black") +
  geom_point(position = position_dodge(width = dodge_width), size = 4, color = onecolour) +
  geom_linerange(aes(xmin = Lower95, xmax = Upper95),
                 position = position_dodge(width = dodge_width), linewidth = 1, color = onecolour) +
  geom_linerange(aes(xmin = Lower90, xmax = Upper90),
                 position = position_dodge(width = dodge_width), linewidth = 2, color = onecolour) +
  theme_minimal() +
  theme(
    axis.title.y = element_blank(),
    title = element_text(size = 16),
    axis.title = element_text(size = 16),
    axis.text.y = element_text(size = 16),
    axis.text.x = element_text(size = 16),
    # "none" is the documented way to suppress the legend.
    legend.position = "none"
  ) +
  labs(x = "Difference in share `probably/definitely should be`", title = "Should it be government responsibility to provide...?") 

ggsave("output/figure_3.jpg",
       p1,
       width = 300, height = 100, units = "mm")

## Figures 4, 5, F1, F2 Preparation ----

# Here we estimate the heterogeneous treatment effects by sex, occupation, age, and education in one go.
# The individual panels and the final figures are created further below.

### Sex

# Treatment and moderator are turned into factors with "0" as reference level.
data$ntreatment <- relevel(factor(data$ntreatment), ref = "0")
data$female <- relevel(factor(data$female), ref = "0")

# One weighted model per outcome, each with a treatment-by-sex interaction.
model_list_sex <- list(
  a1 = lm(ai_surveillance2 ~ ntreatment * female, data = data, weights = weight),
  a5 = lm(ai_threat2 ~ ntreatment * female, data = data, weights = weight),
  a4 = lm(gov_limit2 ~ ntreatment * female, data = data, weights = weight),
  a3 = lm(societal_benefit2 ~ ntreatment * female, data = data, weights = weight),
  a2 = lm(personal_benefit2 ~ ntreatment * female, data = data, weights = weight),
  r2 = lm(replace_ai2 ~ ntreatment * female, data = data, weights = weight),
  r1 = lm(replace_skill2 ~ ntreatment * female, data = data, weights = weight),
  s1 = lm(resp_training2 ~ ntreatment * female, data = data, weights = weight),
  s2 = lm(resp_job2 ~ ntreatment * female, data = data, weights = weight),
  s3 = lm(resp_benefits2 ~ ntreatment * female, data = data, weights = weight)
)

# Run the helper on every model, tag its rows with the model key and the
# moderator name, and stack the pieces into one data frame.
marginal_results_sex <- bind_rows(lapply(names(model_list_sex), function(model_key) {
  effects <- extract_marginal_effects(model_list_sex[[model_key]], "ntreatment", "female")
  effects$model_name <- model_key
  effects$interaction <- "sex"
  effects
}))

# Readable outcome label for each model key (used later as the panel title).
marginal_results_sex$outcome <- unname(c(
  r1 = "Lose job: skilled human",
  r2 = "Lose job: technology",
  a1 = "Surveillance",
  a2 = "Own benefit",
  a3 = "Societal benefit",
  a4 = "Gov. limitations",
  a5 = "AI threat",
  s1 = "Training",
  s2 = "Job guarantee",
  s3 = "UE benefits"
)[marginal_results_sex$model_name])

### Occupation

# Moderator as a factor with "0" as reference level.
data$prof_man <- relevel(factor(data$prof_man), ref = "0")

# One weighted model per outcome, each with a treatment-by-occupation
# interaction. No further covariates are included.
model_list_occ <- list(
  a1 = lm(ai_surveillance2 ~ ntreatment * prof_man, data = data, weights = weight),
  a5 = lm(ai_threat2 ~ ntreatment * prof_man, data = data, weights = weight),
  a4 = lm(gov_limit2 ~ ntreatment * prof_man, data = data, weights = weight),
  a3 = lm(societal_benefit2 ~ ntreatment * prof_man, data = data, weights = weight),
  a2 = lm(personal_benefit2 ~ ntreatment * prof_man, data = data, weights = weight),
  r2 = lm(replace_ai2 ~ ntreatment * prof_man, data = data, weights = weight),
  r1 = lm(replace_skill2 ~ ntreatment * prof_man, data = data, weights = weight),
  s1 = lm(resp_training2 ~ ntreatment * prof_man, data = data, weights = weight),
  s2 = lm(resp_job2 ~ ntreatment * prof_man, data = data, weights = weight),
  s3 = lm(resp_benefits2 ~ ntreatment * prof_man, data = data, weights = weight)
)

# Run the helper on every model, tag its rows with the model key and the
# moderator name, and stack the pieces into one data frame.
marginal_results_occ <- bind_rows(lapply(names(model_list_occ), function(model_key) {
  effects <- extract_marginal_effects(model_list_occ[[model_key]], "ntreatment", "prof_man")
  effects$model_name <- model_key
  effects$interaction <- "occ"
  effects
}))

# Readable outcome label for each model key (used later as the panel title).
marginal_results_occ$outcome <- unname(c(
  r1 = "Lose job: skilled human",
  r2 = "Lose job: technology",
  a1 = "Surveillance",
  a2 = "Own benefit",
  a3 = "Societal benefit",
  a4 = "Gov. limitations",
  a5 = "AI threat",
  s1 = "Training",
  s2 = "Job guarantee",
  s3 = "UE benefits"
)[marginal_results_occ$model_name])


### Age

# Moderator as a factor with "0" as reference level.
data$under40 <- relevel(factor(data$under40), ref = "0")

# One weighted model per outcome, each with a treatment-by-age-group
# interaction. No further covariates are included.
model_list_age <- list(
  a1 = lm(ai_surveillance2 ~ ntreatment * under40, data = data, weights = weight),
  a5 = lm(ai_threat2 ~ ntreatment * under40, data = data, weights = weight),
  a4 = lm(gov_limit2 ~ ntreatment * under40, data = data, weights = weight),
  a3 = lm(societal_benefit2 ~ ntreatment * under40, data = data, weights = weight),
  a2 = lm(personal_benefit2 ~ ntreatment * under40, data = data, weights = weight),
  r2 = lm(replace_ai2 ~ ntreatment * under40, data = data, weights = weight),
  r1 = lm(replace_skill2 ~ ntreatment * under40, data = data, weights = weight),
  s1 = lm(resp_training2 ~ ntreatment * under40, data = data, weights = weight),
  s2 = lm(resp_job2 ~ ntreatment * under40, data = data, weights = weight),
  s3 = lm(resp_benefits2 ~ ntreatment * under40, data = data, weights = weight)
)

# Run the helper on every model, tag its rows with the model key and the
# moderator name, and stack the pieces into one data frame.
marginal_results_age <- bind_rows(lapply(names(model_list_age), function(model_key) {
  effects <- extract_marginal_effects(model_list_age[[model_key]], "ntreatment", "under40")
  effects$model_name <- model_key
  effects$interaction <- "age"
  effects
}))

# Readable outcome label for each model key (used later as the panel title).
marginal_results_age$outcome <- unname(c(
  r1 = "Lose job: skilled human",
  r2 = "Lose job: technology",
  a1 = "Surveillance",
  a2 = "Own benefit",
  a3 = "Societal benefit",
  a4 = "Gov. limitations",
  a5 = "AI threat",
  s1 = "Training",
  s2 = "Job guarantee",
  s3 = "UE benefits"
)[marginal_results_age$model_name])


### Education

# Treatment and moderator as factors with "0" as reference level.
data$ntreatment <- relevel(factor(data$ntreatment), ref = "0")
data$degree <- relevel(factor(data$degree), ref = "0")

# One weighted model per outcome, each with a treatment-by-degree interaction.
# No further covariates are included.
model_list_edu <- list(
  a1 = lm(ai_surveillance2 ~ ntreatment * degree, data = data, weights = weight),
  a5 = lm(ai_threat2 ~ ntreatment * degree, data = data, weights = weight),
  a4 = lm(gov_limit2 ~ ntreatment * degree, data = data, weights = weight),
  a3 = lm(societal_benefit2 ~ ntreatment * degree, data = data, weights = weight),
  a2 = lm(personal_benefit2 ~ ntreatment * degree, data = data, weights = weight),
  r2 = lm(replace_ai2 ~ ntreatment * degree, data = data, weights = weight),
  r1 = lm(replace_skill2 ~ ntreatment * degree, data = data, weights = weight),
  s1 = lm(resp_training2 ~ ntreatment * degree, data = data, weights = weight),
  s2 = lm(resp_job2 ~ ntreatment * degree, data = data, weights = weight),
  s3 = lm(resp_benefits2 ~ ntreatment * degree, data = data, weights = weight)
)

# Run the helper on every model, tag its rows with the model key and the
# moderator name, and stack the pieces into one data frame.
marginal_results_edu <- bind_rows(lapply(names(model_list_edu), function(model_key) {
  effects <- extract_marginal_effects(model_list_edu[[model_key]], "ntreatment", "degree")
  effects$model_name <- model_key
  effects$interaction <- "edu"
  effects
}))

# Readable outcome label for each model key (used later as the panel title).
marginal_results_edu$outcome <- unname(c(
  r1 = "Lose job: skilled human",
  r2 = "Lose job: technology",
  a1 = "Surveillance",
  a2 = "Own benefit",
  a3 = "Societal benefit",
  a4 = "Gov. limitations",
  a5 = "AI threat",
  s1 = "Training",
  s2 = "Job guarantee",
  s3 = "UE benefits"
)[marginal_results_edu$model_name])


## Combined Figures 4, 5, F1, F2 ----
# Stack the marginal effects for all four moderators into one data frame.
marginal_results <- rbind(
  marginal_results_age,
  marginal_results_sex,
  marginal_results_edu,
  marginal_results_occ
)

# 1. Identify unique (model_name, interaction) combos; each becomes one panel.
combos <- marginal_results %>%
  distinct(model_name, interaction)

# 2. Build one panel per combination. The list is produced in a single pass
#    instead of being grown inside a loop, and is named "<model_name>_<interaction>"
#    (e.g. "r1_sex") so the figures below can pick panels by name.
plot_list <- lapply(seq_len(nrow(combos)), function(i) {
  
  mn    <- combos$model_name[i]
  inter <- combos$interaction[i]
  
  # Rows belonging to this outcome/moderator pair
  df_sub <- marginal_results %>%
    filter(model_name == mn, interaction == inter)
  
  # Panel title: the outcome label (assumes exactly one unique 'outcome' per subset)
  outcome_title <- unique(df_sub$outcome)
  
  # One label per (X, Z) cell; it drives both the vertical position and the shape
  df_sub <- df_sub %>%
    mutate(xz_combo = factor(paste0("X=", X_level, ",Z=", Z_level)))
  
  ggplot(df_sub, aes(
    y    = xz_combo,        # combos on the vertical axis
    x    = difference,      # difference from baseline horizontally
    shape = xz_combo        # each combo a different shape
  )) +
    # Point estimates
    geom_point(size = 5, color = onecolour) +
    # 95% CIs (thin)
    geom_linerange(aes(xmin = ci95_lower, xmax = ci95_upper), linewidth = 1, color = onecolour) +
    # 90% CIs (thick)
    geom_linerange(aes(xmin = ci90_lower, xmax = ci90_upper), linewidth = 2, color = onecolour) +
    # Vertical zero line
    geom_vline(xintercept = 0, linetype = "dashed") +
    
    # Shapes for each (X,Z) combination; the codes must stay in sync with the
    # shape codes used for the legend grobs further below.
    scale_shape_manual(values = c(
      "X=1,Z=0" = 16,  # filled circle
      "X=0,Z=1" = 15,  # filled square
      "X=1,Z=1" = 17   # filled triangle
    )) +
    
    # Fixed x-axis label; the outcome label is the panel title
    labs(x = "Marginal effect",
         y = NULL,
         title = outcome_title) +
    
    theme_minimal() +
    theme(
      legend.position = "none",
      # Remove y-axis ticks & text (the legend grob explains the shapes)
      axis.text.y = element_blank(),
      axis.ticks.y = element_blank(),
      axis.text.x = element_text(size = 14),
      axis.title.x = element_text(size = 14),
      title = element_text(size = 16)
    )
})
names(plot_list) <- paste0(combos$model_name, "_", combos$interaction)


# Create titles: one title grob per figure section
title_list <- list()
title_list[["risk"]] <- make_title_grob(title_text = "A: Subjective Risk", font_size = 18)
title_list[["attitudes"]] <- make_title_grob(title_text = "B: Attitudes", font_size = 18)
title_list[["policy"]] <- make_title_grob(title_text = "C: Social Policy", font_size = 18)


# Create legends: one legend grob per moderator. The shape codes (15, 16, 17)
# are the same codes that the panels assign to the three plotted
# treatment/moderator cells.
legend_list <- list()
sex_shapes <- c("Female, control" = 15, "Male, treated" = 16, "Female, treated" = 17)
legend_list[["sex"]] <- make_legend_grob(shape_mapping = sex_shapes)
age_shapes <- c("Under 40, control" = 15, "40+, treated" = 16, "Under 40, treated" = 17)
legend_list[["age"]] <- make_legend_grob(shape_mapping = age_shapes)
degree_shapes <- c("Degree, control" = 15, "No degree, treated" = 16, "Degree, treated" = 17)
legend_list[["edu"]] <- make_legend_grob(shape_mapping = degree_shapes)
occ_shapes <- c("Professional, control" = 15, "Non-professional, treated" = 16, "Professional, treated" = 17)
legend_list[["occ"]] <- make_legend_grob(shape_mapping = occ_shapes)
freq_shapes <- c("Frequent, control" = 15, "Infrequent, treated" = 16, "Frequent, treated" = 17)
legend_list[["freq"]] <- make_legend_grob(shape_mapping = freq_shapes)


# Layout matrix for combined figures. Each number is a position in the grobs
# list passed to grid.arrange(): 1-3 are the section titles, 4-13 the ten
# panels, 14 the legend, and 15 an empty spacer used for blank cells and rows.
het_matrix <- matrix(
  c(1,1,1,
    4,5,15,
    15,15,15,
    2,2,2,
    6,7,8,
    15,15,15,
    9,10,15,
    15,15,15,
    3,3,3,
    11,12,13,
    14,14,14),
  nrow = 11, ncol = 3, byrow = TRUE)

# Empty grob that fills the blank cells of the layout
spacer <- grid::nullGrob()

### Figure 4 ----
# Grobs in the order expected by het_matrix: section titles (1-3), the two
# risk panels (4-5), the five attitude panels (6-10), the three policy panels
# (11-13), the legend (14) and the empty spacer (15).
grobs_sex <- unname(c(
  title_list[c("risk", "attitudes", "policy")],
  plot_list[paste0(c("r1", "r2", "a1", "a2", "a3", "a4", "a5", "s1", "s2", "s3"), "_sex")],
  list(legend_list[["sex"]], spacer)
))

# Assemble the panels; the heights give the relative size of the 11 layout rows.
fig_sex <- grid.arrange(
  grobs = grobs_sex,
  layout_matrix = het_matrix,
  heights = c(0.3, 1, 0.2, 0.3, 1, 0.2, 1, 0.2, 0.3, 1, 0.3)
)

ggsave(filename = "output/figure_4.jpg", plot = fig_sex,
       width = 250, height = 270, units = "mm")


### Figure 5 ----
# Grobs in the order expected by het_matrix: section titles (1-3), the two
# risk panels (4-5), the five attitude panels (6-10), the three policy panels
# (11-13), the legend (14) and the empty spacer (15).
grobs_occ <- unname(c(
  title_list[c("risk", "attitudes", "policy")],
  plot_list[paste0(c("r1", "r2", "a1", "a2", "a3", "a4", "a5", "s1", "s2", "s3"), "_occ")],
  list(legend_list[["occ"]], spacer)
))

# Assemble the panels; the heights give the relative size of the 11 layout rows.
fig_occ <- grid.arrange(
  grobs = grobs_occ,
  layout_matrix = het_matrix,
  heights = c(0.3, 1, 0.2, 0.3, 1, 0.2, 1, 0.2, 0.3, 1, 0.3)
)

ggsave(filename = "output/figure_5.jpg", plot = fig_occ,
       width = 250, height = 270, units = "mm")


### Figure F1 ----
# Grobs in the order expected by het_matrix: section titles (1-3), the two
# risk panels (4-5), the five attitude panels (6-10), the three policy panels
# (11-13), the legend (14) and the empty spacer (15).
grobs_age <- unname(c(
  title_list[c("risk", "attitudes", "policy")],
  plot_list[paste0(c("r1", "r2", "a1", "a2", "a3", "a4", "a5", "s1", "s2", "s3"), "_age")],
  list(legend_list[["age"]], spacer)
))

# Assemble the panels; the heights give the relative size of the 11 layout rows.
fig_age <- grid.arrange(
  grobs = grobs_age,
  layout_matrix = het_matrix,
  heights = c(0.3, 1, 0.2, 0.3, 1, 0.2, 1, 0.2, 0.3, 1, 0.3)
)

ggsave(filename = "output/figure_f1.jpg", plot = fig_age,
       width = 250, height = 270, units = "mm")



### Figure F2 ----
# Grobs in the order expected by het_matrix: section titles (1-3), the two
# risk panels (4-5), the five attitude panels (6-10), the three policy panels
# (11-13), the legend (14) and the empty spacer (15).
grobs_edu <- unname(c(
  title_list[c("risk", "attitudes", "policy")],
  plot_list[paste0(c("r1", "r2", "a1", "a2", "a3", "a4", "a5", "s1", "s2", "s3"), "_edu")],
  list(legend_list[["edu"]], spacer)
))

# Assemble the panels; the heights give the relative size of the 11 layout rows.
fig_edu <- grid.arrange(
  grobs = grobs_edu,
  layout_matrix = het_matrix,
  heights = c(0.3, 1, 0.2, 0.3, 1, 0.2, 1, 0.2, 0.3, 1, 0.3)
)

ggsave(filename = "output/figure_f2.jpg", plot = fig_edu,
       width = 250, height = 270, units = "mm")

## Figure 5 ----

# see "Combined Figures" above

## Figure 6 ----

# Overlaid density curves of the sentiment score, one per treatment group.
# NOTE(review): the positions of the two dashed lines are hard-coded --
# presumably the group means of `sentiment`; confirm and update them if
# textdata changes. Their colours are hard-coded too and are assumed to
# correspond to the fill colours in `twocolours`.
sent <- ggplot(data = textdata, mapping = aes(x = sentiment, fill = treatment)) +
  geom_density(alpha = 0.5) +
  geom_vline(xintercept = 0.2120954, color = "#004488", linetype = "dashed") +
  geom_vline(xintercept = 0.1950732, color = "#BB5566", linetype = "dashed") +
  scale_fill_manual(values = twocolours) +
  labs(
    title = "",
    x = "Sentiment Score",
    y = "Density",
    fill = "Group"
  ) +
  theme_minimal() +
  theme(
    legend.title = element_text(size = 18, face = "plain"),
    legend.text = element_text(size = 18),
    axis.text = element_text(size = 18, color = "black"),
    axis.title = element_text(size = 18)
  )

ggsave(filename = "output/figure_6.jpg", plot = sent,
       width = 300, height = 120, units = "mm")

## Figure 7 ----

# Keep the open-text answer (`conclusion`) together with the respondent-level
# variables that are carried along as document metadata.
answers <- dplyr::select(
  textdata,
  id, conclusion, ntreatment, female, degree, age, prof_man, overall_score_gpt4
)

# Pre-process the texts, then build the corpus with a lower frequency
# threshold of 5 for vocabulary terms.
processed <- textProcessor(documents = answers$conclusion, metadata = answers)
out <- prepDocuments(
  documents = processed$documents,
  vocab = processed$vocab,
  meta = processed$meta,
  lower.thresh = 5
)
docs <- out$documents
vocab <- out$vocab
meta <- out$meta

# Drop documents with a missing `degree`, so that this model uses the same
# documents as the models that include degree as a covariate.
valid_indices <- !is.na(meta$degree)
filtered_docs <- docs[valid_indices]
filtered_meta <- meta[valid_indices, ]

# Three-topic structural topic model with treatment as the only prevalence
# covariate.
fit_treat <- stm(
  documents = filtered_docs,
  vocab = vocab,
  K = 3,
  prevalence = ~ ntreatment,
  data = filtered_meta,
  init.type = "Spectral",
  max.em.its = 300
)

# Inspect the fitted topics: top words per topic and the summary plot.
labelTopics(fit_treat)
plot(fit_treat, type = "summary", main = "", n = 5)


# Extract the expected topic proportions (average across documents)
topic_props <- data.frame(
  topic = seq_len(fit_treat$settings$dim$K),
  proportion = colMeans(fit_treat$theta)
)

# Get the top 5 words for each topic and collapse them into one label per topic
topic_labels <- apply(labelTopics(fit_treat, n = 5)$prob, 1, paste, collapse = ", ")
topic_props$label <- topic_labels

# Hand-assigned topic names for topics 1, 2 and 3.
# NOTE(review): these names presumably come from inspecting the fitted topics;
# re-check them if the model or the data change, since they are tied to the
# topic numbering of this particular fit.
topic_props$TopicLabel <- factor(topic_props$topic, levels = c(1, 2, 3), 
                                 labels = c("Intervention", "Safety and Ethics", "Limitations"))
# Reverse the levels so that Topic 1 appears at the top after coord_flip()
topic_props$TopicLabel <- fct_rev(topic_props$TopicLabel)

# Horizontal bars of expected topic proportions, with each topic's top words
# printed inside its bar (starting at y = 0.01).
stm1 <- ggplot(topic_props, aes(x = TopicLabel, y = proportion, fill = factor(topic))) +
  geom_bar(stat = "identity", width = 0.6, alpha = 0.7) +
  geom_text(aes(label = label, y = 0.01), hjust = 0, color = "black", size = 7.5) +
  coord_flip() +
  labs(x = "", y = "Expected Proportion", title = "") +
  scale_fill_manual(values = threecolours) +
  theme_minimal() +
  theme(legend.position = "none", 
        axis.title.x = element_text(size = 20, colour = "black"),
        axis.text.y = element_text(size = 20, colour = "black"),
        axis.text.x = element_text(size = 16, colour = "black"))

stm1_title <- make_title_grob(title_text = "A: Topic Prevalence", font_size = 22)

# Single-column layout: title (1), plot (2), empty spacer (3)
stm1_matrix <- matrix(
  c(1,
    2,
    3),
  nrow = 3, ncol = 1, byrow = TRUE)

spacer <- grid::nullGrob()
grobs_stm1 <- list(stm1_title, stm1, spacer)

# The heights give the relative size of the three layout rows
fig_stm1 <- grid.arrange(grobs = grobs_stm1,
                         layout_matrix = stm1_matrix,
                         heights = c(0.12, 1, 0.15))

ggsave("output/figure_7a.jpg", fig_stm1, width = 300, height = 150, units = "mm")


### Panel B

# Estimate the effect of treatment on the prevalence of each of the 3 topics.
# The seed is fixed beforehand, presumably because the estimation involves
# random draws (uncertainty = "Global") -- confirm against the stm docs.
set.seed(12345)
prep_treat <- estimateEffect(1:3 ~ ntreatment, fit_treat,  meta = filtered_meta, uncertainty = "Global")

# Get the summary tables for each topic
effect_tables <- summary(prep_treat)$tables

# Topic names, in topic order (topic 1, 2, 3)
term_labels <- c("Intervention", "Safety and Ethics", "Limitations")

# Combine the tables into a single data frame with one row per topic.
# Only the coefficient row named exactly "ntreatment" is used; the interval
# bounds are estimate +/- 1.96 (95%) and +/- 1.645 (90%) standard errors.
results <- do.call(rbind, lapply(seq_along(effect_tables), function(i) {
  coef_row <- effect_tables[[i]]["ntreatment", ]
  est <- coef_row["Estimate"]
  se <- coef_row["Std. Error"]
  data.frame(
    Term = term_labels[i],
    Coefficient = est,
    se = se,
    Lower95 = est - 1.96 * se,
    Upper95 = est + 1.96 * se,
    Lower90 = est - 1.645 * se,
    Upper90 = est + 1.645 * se
  )
}))

results$Term <- factor(results$Term, levels = term_labels)

# View the resulting data frame
print(results)


# Coefficient plot: point estimate, thin 95% range, thick 90% range, one
# colour per topic. fct_rev() puts the first topic at the top of the y axis.
stm2 <- ggplot(results, aes(x = Coefficient, y = fct_rev(Term), color = Term, fill = Term)) +
  geom_vline(xintercept = 0, linetype = "solid", color = "black") +
  geom_point(position = position_dodge(width = dodge_width), size = 4) +
  geom_linerange(aes(xmin = Lower95, xmax = Upper95),
                 position = position_dodge(width = dodge_width), linewidth = 1) +
  geom_linerange(aes(xmin = Lower90, xmax = Upper90),
                 position = position_dodge(width = dodge_width), linewidth = 2) +
  scale_color_manual(values = threecolours) +
  scale_fill_manual(values = threecolours) +
  theme_minimal() +
  theme(axis.title.y = element_blank(),
        axis.title = element_text(size = 20), 
        axis.text.y = element_text(size = 20, colour = "black"), 
        axis.text.x = element_text(size = 16, colour = "black"), 
        legend.position = "none") +
  labs(x = "Treatment effect on topic prevalence", title = "")

stm2_title <- make_title_grob(title_text = "B: Treatment Effects", font_size = 22)

# Single-column layout: title (1), plot (2), empty spacer (3)
stm2_matrix <- matrix(
  c(1,
    2,
    3),
  nrow = 3, ncol = 1, byrow = TRUE)

grobs_stm2 <- list(stm2_title, stm2, spacer)

# The heights give the relative size of the three layout rows
fig_stm2 <- grid.arrange(grobs = grobs_stm2,
                         layout_matrix = stm2_matrix,
                         heights = c(0.12, 1, 0.15))

ggsave("output/figure_7b.jpg", fig_stm2, height = 150, width = 300, units = "mm")



# Appendix ----

## Table B1 ----

# Descriptive statistics of the listed respondent characteristics by treatment
# group (column percentages), rendered as a flextable.
as_flextable(
  crosstable(data, c(age, female, degree, prof_man, hh_income, skills_avg),
             by = treatment, margin = "column"),
  keep_id = FALSE
)

## Table B2 ----

# Covariate balance between the treatment groups, with a balance threshold of
# 0.1 on mean differences (`m`).
covs <- subset(
  data,
  select = c(age, female, degree, prof_man, hh_income, education_vshort_lab, skills_avg)
)
bal.tab(covs, treat = data$treatment, thresholds = c(m = 0.1))

## Figure C1 ----

# One panel per task, each built by the plot_canvas_simtask() helper from the
# corresponding task*similar_s item, split by treatment group.
pt1 <- plot_canvas_simtask(df = data, q = data$task1similar_s, cat = data$treatment, labs = "Email task")
pt2 <- plot_canvas_simtask(df = data, q = data$task2similar_s, cat = data$treatment, labs = "Assessment task")
pt3 <- plot_canvas_simtask(df = data, q = data$task3similar_s, cat = data$treatment, labs = "Comprehension task")

# Shared heading above the panels and a colour-coded group key below them.
title_grob <- textGrob(
  "Frequency of tasks similar to...",
  gp = gpar(fontsize = 24, fontface = "bold")
)
legend_grob <- richtext_grob(
  'By group: <span style="color:#004488"><b>control</b></span> and <span style="color:#DDAA33"><b>treatment</b></span>', 
  gp = gpar(fontsize = 20)
)

# Three panels side by side.
pt.all <- gridExtra::grid.arrange(
  grobs = list(pt1, pt2, pt3),
  ncol = 3,
  top = title_grob,
  bottom = legend_grob
)
ggsave(filename = "output/figure_c1.jpg", plot = pt.all,
       width = 350, height = 175, units = "mm")

## Figure C2 ----

# One panel per task, each built by the plot_canvas_useful() helper from the
# corresponding ai_helpful*_n item, split by treatment group.
pt1 <- plot_canvas_useful(df = data, q = data$ai_helpful1_n, cat = data$treatment, labs = "Email task")
pt2 <- plot_canvas_useful(df = data, q = data$ai_helpful2_n, cat = data$treatment, labs = "Assessment task")
pt3 <- plot_canvas_useful(df = data, q = data$ai_helpful3_n, cat = data$treatment, labs = "Comprehension task")

# Shared heading above the panels and a colour-coded group key below them.
title_grob <- textGrob(
  "Usefulness of AI in...",
  gp = gpar(fontsize = 24, fontface = "bold")
)
legend_grob <- richtext_grob(
  'By group: <span style="color:#004488"><b>control</b></span> and <span style="color:#DDAA33"><b>treatment</b></span>', 
  gp = gpar(fontsize = 20)
)

# Three panels side by side.
pt.all <- gridExtra::grid.arrange(
  grobs = list(pt1, pt2, pt3),
  ncol = 3,
  top = title_grob,
  bottom = legend_grob
)
ggsave(filename = "output/figure_c2.jpg", plot = pt.all,
       width = 350, height = 175, units = "mm")

## Figure C13 ----

# Performance data are under embargo due to their use in an as yet (08/25) unpublished paper. 
# Please consult the corresponding author (Matthias Haslberger) for access to the data and code.

## Table D1 ----
# Weighted regressions of the risk and attitude outcomes on `treatment`
# (the same specifications as in Figure 2). Note that the personal-benefit
# model comes before the societal-benefit model in this table.
t1 <- lm(replace_skill2 ~ treatment, data = data, weights = weight)
t2 <- lm(replace_ai2 ~ treatment, data = data, weights = weight)
t3 <- lm(ai_surveillance2 ~ treatment, data = data, weights = weight)
t4 <- lm(personal_benefit2 ~ treatment, data = data, weights = weight)
t5 <- lm(societal_benefit2 ~ treatment, data = data, weights = weight)
t6 <- lm(gov_limit2 ~ treatment, data = data, weights = weight)
t7 <- lm(ai_threat2 ~ treatment, data = data, weights = weight)

# Print the regression table as LaTeX code
stargazer(t1, t2, t3, t4, t5, t6, t7, type = "latex")

## Table D2 ----
# Weighted regressions of the three government-responsibility outcomes on
# `treatment` (the same specifications as in Figure 3).
t1 <- lm(resp_training2 ~ treatment, data = data, weights = weight)
t2 <- lm(resp_job2 ~ treatment, data = data, weights = weight)
t3 <- lm(resp_benefits2 ~ treatment, data = data, weights = weight)

# Print the regression table as LaTeX code
stargazer(t1, t2, t3, type = "latex")

## Table D3 ----
# Weighted regressions of the risk and attitude outcomes with a
# treatment-by-sex interaction.
r1 <- lm(replace_skill2 ~ ntreatment * female, data = data, weights = weight)
r2 <- lm(replace_ai2 ~ ntreatment * female, data = data, weights = weight)
a1 <- lm(ai_surveillance2 ~ ntreatment * female, data = data, weights = weight)
a2 <- lm(personal_benefit2 ~ ntreatment * female, data = data, weights = weight)
a3 <- lm(societal_benefit2 ~ ntreatment * female, data = data, weights = weight)
a4 <- lm(gov_limit2 ~ ntreatment * female, data = data, weights = weight)
a5 <- lm(ai_threat2 ~ ntreatment * female, data = data, weights = weight)

# Print the regression table as LaTeX code
stargazer(r1, r2, a1, a2, a3, a4, a5, type = "latex")

## Table D4 ----
# Weighted regressions of the government-responsibility outcomes with a
# treatment-by-sex interaction.
s1 <- lm(resp_training2 ~ ntreatment * female, data = data, weights = weight)
s2 <- lm(resp_job2 ~ ntreatment * female, data = data, weights = weight)
s3 <- lm(resp_benefits2 ~ ntreatment * female, data = data, weights = weight)

# Print the regression table as LaTeX code
stargazer(s1, s2, s3, type = "latex")

## Table D5 ----
# Weighted regressions of the risk and attitude outcomes with a
# treatment-by-occupation (prof_man) interaction.
r1 <- lm(replace_skill2 ~ ntreatment * prof_man, data = data, weights = weight)
r2 <- lm(replace_ai2 ~ ntreatment * prof_man, data = data, weights = weight)
a1 <- lm(ai_surveillance2 ~ ntreatment * prof_man, data = data, weights = weight)
a2 <- lm(personal_benefit2 ~ ntreatment * prof_man, data = data, weights = weight)
a3 <- lm(societal_benefit2 ~ ntreatment * prof_man, data = data, weights = weight)
a4 <- lm(gov_limit2 ~ ntreatment * prof_man, data = data, weights = weight)
a5 <- lm(ai_threat2 ~ ntreatment * prof_man, data = data, weights = weight)

# Print the regression table as LaTeX code
stargazer(r1, r2, a1, a2, a3, a4, a5, type = "latex")

## Table D6 ----
# Weighted regressions of the government-responsibility outcomes with a
# treatment-by-occupation (prof_man) interaction.
s1 <- lm(resp_training2 ~ ntreatment * prof_man, data = data, weights = weight)
s2 <- lm(resp_job2 ~ ntreatment * prof_man, data = data, weights = weight)
s3 <- lm(resp_benefits2 ~ ntreatment * prof_man, data = data, weights = weight)

# Print the regression table as LaTeX code
stargazer(s1, s2, s3, type = "latex")

## Table D7 ----

# This section repeats the text preparation and the treatment-only topic model
# from Figure 7, so it can be run without the Figure 7 code.

# Open-text answer (`conclusion`) plus the variables kept as document metadata.
answers <- dplyr::select(
  textdata,
  id, conclusion, ntreatment, female, degree, age, prof_man, overall_score_gpt4
)

# Pre-process the texts and build the corpus (lower frequency threshold: 5).
processed <- textProcessor(documents = answers$conclusion, metadata = answers)
out <- prepDocuments(
  documents = processed$documents,
  vocab = processed$vocab,
  meta = processed$meta,
  lower.thresh = 5
)
docs <- out$documents
vocab <- out$vocab
meta <- out$meta

# Keep only documents with a non-missing `degree`, so that all topic models
# are estimated on the same set of documents.
valid_indices <- !is.na(meta$degree)
filtered_docs <- docs[valid_indices]
filtered_meta <- meta[valid_indices, ]

# Three-topic model with treatment as the only prevalence covariate.
fit_treat <- stm(
  documents = filtered_docs,
  vocab = vocab,
  K = 3,
  prevalence = ~ ntreatment,
  data = filtered_meta,
  init.type = "Spectral",
  max.em.its = 300
)

# Treatment effect on the prevalence of each topic; the seed is fixed first.
set.seed(12345)
prep_treat <- estimateEffect(
  1:3 ~ ntreatment,
  fit_treat,
  meta = filtered_meta,
  uncertainty = "Global"
)
summary(prep_treat)

## Table E2 ----

# Open-text answer (`conclusion`) plus the variables kept as document metadata.
answers <- dplyr::select(
  textdata,
  id, conclusion, ntreatment, female, degree, age, prof_man, overall_score_gpt4
)

# Pre-process the texts and build the corpus (lower frequency threshold: 5).
processed <- textProcessor(documents = answers$conclusion, metadata = answers)
out <- prepDocuments(
  documents = processed$documents,
  vocab = processed$vocab,
  meta = processed$meta,
  lower.thresh = 5
)
docs <- out$documents
vocab <- out$vocab
meta <- out$meta

# Keep only documents with a non-missing `degree`, which is a covariate below.
valid_indices <- !is.na(meta$degree)
filtered_docs <- docs[valid_indices]
filtered_meta <- meta[valid_indices, ]

# Three-topic model with the full set of prevalence covariates.
fit_full <- stm(
  documents = filtered_docs,
  vocab = vocab,
  K = 3,
  prevalence = ~ ntreatment + female + degree + age + prof_man,
  data = filtered_meta,
  init.type = "Spectral",
  max.em.its = 300
)

# Covariate effects on the prevalence of each topic; the seed is fixed first.
set.seed(12345)
prep_full <- estimateEffect(
  1:3 ~ ntreatment + age + female + degree + prof_man,
  fit_full,
  meta = filtered_meta,
  uncertainty = "Global"
)
summary(prep_full)

## Figure F1 ----

# see "Combined Figures" above

## Figure F2 ----

# see "Combined Figures" above

## Table F1 ----
# Weighted OLS of each outcome on ntreatment x under40 (main effects plus
# interaction); all seven models go into one LaTeX table.
r1 <- lm(replace_skill2 ~ ntreatment * under40, data = data, weights = weight)
r2 <- lm(replace_ai2 ~ ntreatment * under40, data = data, weights = weight)
a1 <- lm(
  ai_surveillance2 ~ ntreatment * under40,
  data = data, weights = weight
)
a2 <- lm(
  personal_benefit2 ~ ntreatment * under40,
  data = data, weights = weight
)
a3 <- lm(
  societal_benefit2 ~ ntreatment * under40,
  data = data, weights = weight
)
a4 <- lm(gov_limit2 ~ ntreatment * under40, data = data, weights = weight)
a5 <- lm(ai_threat2 ~ ntreatment * under40, data = data, weights = weight)

stargazer(r1, r2, a1, a2, a3, a4, a5, type = "latex")

## Table F2 ----
# Weighted OLS of the three resp_* outcomes on ntreatment x under40,
# printed as a LaTeX table.
s1 <- lm(
  resp_training2 ~ ntreatment * under40,
  data = data, weights = weight
)
s2 <- lm(
  resp_job2 ~ ntreatment * under40,
  data = data, weights = weight
)
s3 <- lm(
  resp_benefits2 ~ ntreatment * under40,
  data = data, weights = weight
)

stargazer(s1, s2, s3, type = "latex")

## Table F3 ----
# Weighted OLS of each outcome on ntreatment x degree (main effects plus
# interaction); all seven models go into one LaTeX table.
r1 <- lm(replace_skill2 ~ ntreatment * degree, data = data, weights = weight)
r2 <- lm(replace_ai2 ~ ntreatment * degree, data = data, weights = weight)
a1 <- lm(
  ai_surveillance2 ~ ntreatment * degree,
  data = data, weights = weight
)
a2 <- lm(
  personal_benefit2 ~ ntreatment * degree,
  data = data, weights = weight
)
a3 <- lm(
  societal_benefit2 ~ ntreatment * degree,
  data = data, weights = weight
)
a4 <- lm(gov_limit2 ~ ntreatment * degree, data = data, weights = weight)
a5 <- lm(ai_threat2 ~ ntreatment * degree, data = data, weights = weight)

stargazer(r1, r2, a1, a2, a3, a4, a5, type = "latex")

## Table F4 ----
# Weighted OLS of the three resp_* outcomes on ntreatment x degree,
# printed as a LaTeX table.
s1 <- lm(
  resp_training2 ~ ntreatment * degree,
  data = data, weights = weight
)
s2 <- lm(
  resp_job2 ~ ntreatment * degree,
  data = data, weights = weight
)
s3 <- lm(
  resp_benefits2 ~ ntreatment * degree,
  data = data, weights = weight
)

stargazer(s1, s2, s3, type = "latex")

## Table G1 ----
# All six models are fitted on the same ntreatment == 1 subsample, so it is
# built once here rather than re-running subset() inside every lm() call.
g1_sample <- subset(data, ntreatment == 1)

# For each ai_helpful*_n item (entered as a factor) fit one model per outcome:
# *.1 = replace_skill2, *.2 = replace_ai2. Models are listed in the order in
# which they are passed to stargazer() below.
m1.1 <- lm(
  replace_skill2 ~ age + female + degree + prof_man + as.factor(ai_helpful1_n),
  data = g1_sample, weights = weight
)
m1.2 <- lm(
  replace_ai2 ~ age + female + degree + prof_man + as.factor(ai_helpful1_n),
  data = g1_sample, weights = weight
)
m2.1 <- lm(
  replace_skill2 ~ age + female + degree + prof_man + as.factor(ai_helpful2_n),
  data = g1_sample, weights = weight
)
m2.2 <- lm(
  replace_ai2 ~ age + female + degree + prof_man + as.factor(ai_helpful2_n),
  data = g1_sample, weights = weight
)
m3.1 <- lm(
  replace_skill2 ~ age + female + degree + prof_man + as.factor(ai_helpful3_n),
  data = g1_sample, weights = weight
)
m3.2 <- lm(
  replace_ai2 ~ age + female + degree + prof_man + as.factor(ai_helpful3_n),
  data = g1_sample, weights = weight
)

stargazer(m1.1, m1.2, m2.1, m2.2, m3.1, m3.2, type = "latex")

## Table G2 ----
# Restrict the sample to rows with gpt_frequent != 1. subset() also drops rows
# where the condition is NA. The subsample is built once and shared by all
# seven models instead of being recomputed in every lm() call.
g2_sample <- subset(data, gpt_frequent != 1)

# Weighted OLS of each outcome on treatment within that subsample.
t1 <- lm(replace_skill2 ~ treatment, data = g2_sample, weights = weight)
t2 <- lm(replace_ai2 ~ treatment, data = g2_sample, weights = weight)
t3 <- lm(ai_surveillance2 ~ treatment, data = g2_sample, weights = weight)
t4 <- lm(personal_benefit2 ~ treatment, data = g2_sample, weights = weight)
t5 <- lm(societal_benefit2 ~ treatment, data = g2_sample, weights = weight)
t6 <- lm(gov_limit2 ~ treatment, data = g2_sample, weights = weight)
t7 <- lm(ai_threat2 ~ treatment, data = g2_sample, weights = weight)

stargazer(t1, t2, t3, t4, t5, t6, t7, type = "latex")

## Table G3 ----
# Restrict the sample to rows with gpt_frequent != 1. subset() also drops rows
# where the condition is NA. The subsample is built once and shared by all
# three models instead of being recomputed in every lm() call.
g3_sample <- subset(data, gpt_frequent != 1)

# Weighted OLS of the three resp_* outcomes on treatment within that subsample.
t1 <- lm(resp_training2 ~ treatment, data = g3_sample, weights = weight)
t2 <- lm(resp_job2 ~ treatment, data = g3_sample, weights = weight)
t3 <- lm(resp_benefits2 ~ treatment, data = g3_sample, weights = weight)

stargazer(t1, t2, t3, type = "latex")

## Table G4 ----
# Weighted OLS of each outcome on treatment, adding female, degree, age and
# pers_income as covariates; all seven models go into one LaTeX table.
t1 <- lm(
  replace_skill2 ~ treatment + female + degree + age + pers_income,
  data = data, weights = weight
)
t2 <- lm(
  replace_ai2 ~ treatment + female + degree + age + pers_income,
  data = data, weights = weight
)
t3 <- lm(
  ai_surveillance2 ~ treatment + female + degree + age + pers_income,
  data = data, weights = weight
)
t4 <- lm(
  personal_benefit2 ~ treatment + female + degree + age + pers_income,
  data = data, weights = weight
)
t5 <- lm(
  societal_benefit2 ~ treatment + female + degree + age + pers_income,
  data = data, weights = weight
)
t6 <- lm(
  gov_limit2 ~ treatment + female + degree + age + pers_income,
  data = data, weights = weight
)
t7 <- lm(
  ai_threat2 ~ treatment + female + degree + age + pers_income,
  data = data, weights = weight
)

stargazer(t1, t2, t3, t4, t5, t6, t7, type = "latex")

## Table G5 ----
# Weighted OLS of the three resp_* outcomes on treatment, adding female,
# degree, age and pers_income as covariates; printed as a LaTeX table.
t1 <- lm(
  resp_training2 ~ treatment + female + degree + age + pers_income,
  data = data, weights = weight
)
t2 <- lm(
  resp_job2 ~ treatment + female + degree + age + pers_income,
  data = data, weights = weight
)
t3 <- lm(
  resp_benefits2 ~ treatment + female + degree + age + pers_income,
  data = data, weights = weight
)

stargazer(t1, t2, t3, type = "latex")

## Table G6 ----
# Restrict the sample to rows with attention == 1. subset() also drops rows
# where the condition is NA. The subsample is built once and shared by all
# seven models instead of being recomputed in every lm() call.
g6_sample <- subset(data, attention == 1)

# Weighted OLS of each outcome on treatment within that subsample.
t1 <- lm(replace_skill2 ~ treatment, data = g6_sample, weights = weight)
t2 <- lm(replace_ai2 ~ treatment, data = g6_sample, weights = weight)
t3 <- lm(ai_surveillance2 ~ treatment, data = g6_sample, weights = weight)
t4 <- lm(personal_benefit2 ~ treatment, data = g6_sample, weights = weight)
t5 <- lm(societal_benefit2 ~ treatment, data = g6_sample, weights = weight)
t6 <- lm(gov_limit2 ~ treatment, data = g6_sample, weights = weight)
t7 <- lm(ai_threat2 ~ treatment, data = g6_sample, weights = weight)

stargazer(t1, t2, t3, t4, t5, t6, t7, type = "latex")

## Table G7 ----
# Restrict the sample to rows with attention == 1. subset() also drops rows
# where the condition is NA. The subsample is built once and shared by all
# three models instead of being recomputed in every lm() call.
g7_sample <- subset(data, attention == 1)

# Weighted OLS of the three resp_* outcomes on treatment within that subsample.
t1 <- lm(resp_training2 ~ treatment, data = g7_sample, weights = weight)
t2 <- lm(resp_job2 ~ treatment, data = g7_sample, weights = weight)
t3 <- lm(resp_benefits2 ~ treatment, data = g7_sample, weights = weight)

stargazer(t1, t2, t3, type = "latex")

## Table G8 ----
# Weighted OLS of ai_tools on treatment, i.e. the regressor/instrument pair
# that the ivreg() models of Tables G9 and G10 use.
t1_fs <- lm(ai_tools ~ treatment, data = data, weights = weight)

stargazer(t1_fs, type = "latex")

## Table G9 ----
# Weighted IV regressions: each outcome on ai_tools, with treatment as the
# instrument (right-hand side of `|`); seven models in one LaTeX table.
t1_ss <- ivreg(
  replace_skill2 ~ ai_tools | treatment,
  data = data, weights = weight
)
t2_ss <- ivreg(
  replace_ai2 ~ ai_tools | treatment,
  data = data, weights = weight
)
t3_ss <- ivreg(
  ai_surveillance2 ~ ai_tools | treatment,
  data = data, weights = weight
)
t4_ss <- ivreg(
  personal_benefit2 ~ ai_tools | treatment,
  data = data, weights = weight
)
t5_ss <- ivreg(
  societal_benefit2 ~ ai_tools | treatment,
  data = data, weights = weight
)
t6_ss <- ivreg(
  gov_limit2 ~ ai_tools | treatment,
  data = data, weights = weight
)
t7_ss <- ivreg(
  ai_threat2 ~ ai_tools | treatment,
  data = data, weights = weight
)

stargazer(t1_ss, t2_ss, t3_ss, t4_ss, t5_ss, t6_ss, t7_ss, type = "latex")

## Table G10 ----
# Weighted IV regressions of the three resp_* outcomes on ai_tools, with
# treatment as the instrument (right-hand side of `|`).
t8_ss <- ivreg(
  resp_training2 ~ ai_tools | treatment,
  data = data, weights = weight
)
t9_ss <- ivreg(
  resp_job2 ~ ai_tools | treatment,
  data = data, weights = weight
)
t10_ss <- ivreg(
  resp_benefits2 ~ ai_tools | treatment,
  data = data, weights = weight
)

stargazer(t8_ss, t9_ss, t10_ss, type = "latex")
